#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May  7 17:11:35 2025

@author: ajtulloch
"""

import numpy as np
import matplotlib.pyplot as plt
import json

def load_json(filename):
    """Load and return the parsed contents of a JSON file.

    Parameters
    ----------
    filename : str or path-like
        Path of the JSON file to read.

    Returns
    -------
    object
        Whatever ``json.load`` produces for the document (a ``dict`` for the
        count files used in this script).

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the file does not contain valid JSON.
    """
    # Decode as UTF-8 explicitly: without it, open() falls back to the
    # platform locale encoding, which can mis-decode non-ASCII JSON text.
    with open(filename, 'r', encoding='utf-8') as file:
        return json.load(file)

def moving_average(data, window_size):
    """Smooth ``data`` with a centred, equally weighted sliding window.

    The window is applied with ``np.convolve(..., 'same')``, so the input is
    implicitly zero-padded at both ends: values within about half a window of
    either edge are pulled towards zero.

    Parameters
    ----------
    data : array-like
        One-dimensional sequence of numeric values.
    window_size : int or float
        Number of points in the window. Non-integral values are truncated
        with ``int()``.

    Returns
    -------
    numpy.ndarray
        The smoothed values, as returned by ``np.convolve`` in 'same' mode.

    Raises
    ------
    ValueError
        If the truncated window size is smaller than 1.
    """
    size = int(window_size)
    if size < 1:
        raise ValueError(f"window_size must be at least 1, got {window_size!r}")
    # Normalise by the actual number of taps so the weights always sum to 1,
    # even when a non-integral window_size was truncated above.
    kernel = np.full(size, 1.0 / size)
    return np.convolve(data, kernel, 'same')

def compare_and_plot(control_file, experimental_files, mutation_type, window_size=5,
                     output_filename="Olig2-NR2-plot-5bpsw.png", trim_start=18, trim_end=14):
    """Plot smoothed, baseline-adjusted log2(experiment / control) counts per position.

    For every position present in the control file and in every experimental
    file, each experimental ``normalized_count`` is divided by the control
    ``mean`` and log2-transformed. Each replicate, and the per-position mean
    of the replicates, is then shifted by its own median, smoothed with
    ``moving_average``, trimmed at both ends and plotted. The figure is saved
    to ``output_filename`` and then shown.

    Parameters
    ----------
    control_file : str
        JSON file read as ``data[mutation_type][position]['mean']``.
    experimental_files : sequence of str
        JSON files read as
        ``data[mutation_type][position]['normalized_count']``.
    mutation_type : str
        Top-level key selecting the per-position table in every file.
    window_size : int, optional
        Sliding-window width passed to ``moving_average``.
    output_filename : str, optional
        Path the figure is written to with ``plt.savefig``.
    trim_start : int, optional
        Number of leading positions excluded from the plot (default 18).
    trim_end : int, optional
        Number of trailing positions excluded from the plot (default 14).

    Raises
    ------
    ValueError
        If no experimental files are given, if the files share no position,
        if a trim value is negative, or if trimming leaves nothing to plot.
    """
    experimental_files = list(experimental_files)
    if not experimental_files:
        raise ValueError("experimental_files must contain at least one file")
    if trim_start < 0 or trim_end < 0:
        raise ValueError("trim_start and trim_end must be non-negative")

    # Parse every file exactly once instead of once per position.
    control_table = load_json(control_file)[mutation_type]
    exp_tables = [load_json(exp_file)[mutation_type] for exp_file in experimental_files]

    # Keep only positions present in the control AND in every experimental
    # file, so each lookup in the loop below is guaranteed to succeed.
    shared = set(control_table)
    for table in exp_tables:
        shared &= table.keys()
    common_positions = sorted(shared, key=int)
    if not common_positions:
        raise ValueError("control and experimental files share no positions")

    # One list of log2 ratios per experimental file, aligned with common_positions
    normalized_data = [[] for _ in exp_tables]
    combined_normalized_data = []

    # Collect control means for analysis
    control_means = []

    # Track positions where no replicate produced a usable (non-NaN) value
    empty_positions = []

    # Normalize and log2 transform each experimental point by the control mean for the position
    for pos in common_positions:
        control_mean = control_table[pos]['mean']
        control_means.append(control_mean)
        replicate_values = []
        for per_file_values, table in zip(normalized_data, exp_tables):
            exp_count = table[pos]['normalized_count']
            normalized_count = exp_count / control_mean if control_mean else np.nan
            # Apply log2 transformation; adding a small constant to avoid log2(0)
            normalized_value = np.log2(normalized_count + 1e-10)
            per_file_values.append(normalized_value)
            replicate_values.append(normalized_value)

        if np.all(np.isnan(replicate_values)):
            # e.g. zero control mean: nothing to average, so record NaN
            # directly rather than letting np.nanmean warn about an empty slice.
            combined_normalized_data.append(np.nan)
            empty_positions.append(pos)
        else:
            combined_normalized_data.append(np.nanmean(replicate_values))

    # Print control mean statistics for analysis
    control_mean_stats = {
        'min': np.min(control_means),
        'max': np.max(control_means),
        'mean': np.mean(control_means),
        'median': np.median(control_means),
        'std': np.std(control_means)
    }
    print("Control Mean Statistics:", control_mean_stats)

    # Print out positions where replicate_values held no usable value
    if empty_positions:
        print(f"Warning: {len(empty_positions)} positions had empty replicate_values. Positions: {empty_positions}")
    else:
        print("No empty replicate_values found.")

    # Apply baseline adjustment by subtracting the median of the combined normalized data
    baseline_median = np.nanmedian(combined_normalized_data)
    combined_normalized_data = [x - baseline_median for x in combined_normalized_data]

    # Apply baseline adjustment to each experimental dataset (each uses its own median)
    for values in normalized_data:
        baseline_median_exp = np.nanmedian(values)
        values[:] = [x - baseline_median_exp for x in values]

    # Apply sliding window average to the log2-normalized data.
    # nan_to_num turns NaN into 0, i.e. the baseline after the adjustment above.
    smoothed_data = [moving_average(np.nan_to_num(values), window_size) for values in normalized_data]
    combined_smoothed_data = moving_average(np.nan_to_num(combined_normalized_data), window_size)

    # Exclude the first trim_start and last trim_end positions. The stop index
    # is computed explicitly (and clamped at 0) so that trim_end == 0 keeps
    # the tail instead of producing an empty "[: -0]" slice.
    shown = slice(trim_start, max(len(common_positions) - trim_end, 0))
    common_positions = common_positions[shown]
    if not common_positions:
        raise ValueError(
            f"no positions left to plot after excluding the first {trim_start} "
            f"and last {trim_end} positions"
        )
    combined_smoothed_data = combined_smoothed_data[shown]
    smoothed_data = [series[shown] for series in smoothed_data]

    # Shift the positions so that the first shown position is 0
    first_position = int(common_positions[0])
    shifted_positions = [int(pos) - first_position for pos in common_positions]

    # Plotting
    plt.figure(figsize=(12, 6))

    # Plot each smoothed and log2-normalized experimental dataset
    # (line colours come from matplotlib's default colour cycle)
    for idx, series in enumerate(smoothed_data, start=1):
        plt.plot(shifted_positions, series, label=f'Sample {idx}', alpha=0.6)

    # Plot the average of all replicates as a bold line
    plt.plot(shifted_positions, combined_smoothed_data, label='Average', color='black', linewidth=2)

    # Set custom x-axis tick marks every 25 positions
    tick_spacing = 25
    plt.xticks(range(0, max(shifted_positions) + 1, tick_spacing))

    # Aesthetics and labels
    plt.xlabel('Position')
    plt.ylabel('Log2 Normalized Counts (Baseline Adjusted)')
    plt.legend()
    plt.grid(False)
    plt.tight_layout()

    # Save the plot as a PNG file
    plt.savefig(output_filename)

    # Show the plot
    plt.show()

# Replace the placeholder paths below with your actual file paths
experimental_files = ['/path-to-experimentA/experiment_counts.json',
                      '/path-to-experimentB/experiment_counts.json',
                      '/path-to-experimentC/experiment_counts.json']  # replace with filenames

control_file = '/path-to/output/PlasmidControl/Olig2_NR1/MonteCarlo_Results.json' # replace with filename


window_size = 5  # Define the size of the sliding window

# Run the analysis only when executed as a script, so importing this module
# for its helper functions does not trigger file reads or open a plot window.
if __name__ == "__main__":
    # Pass the window_size variable defined above (not a repeated literal)
    # so that editing it actually changes the smoothing.
    compare_and_plot(control_file, experimental_files, 'substitutions',
                     window_size=window_size, output_filename="Olig2-NR1-dMPRA.png")
